# src/vol5_k2m_cc/cc_translator.py
"""
Compact-curvature translator:
- Multi-scale LoG
- Auto-polarity (+LoG vs -LoG)
- Positive-tail quantile threshold
- Adaptive coverage (fallback to top-|LoG| selection)
- Deterministic final mask with morphology + largest component
"""

from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
from scipy.ndimage import (
    gaussian_laplace,
    binary_opening,
    binary_closing,
    binary_fill_holes,
    label,
)


@dataclass
class CCTConfig:
    """Settings for the compact-curvature translator (``build_mask``).

    Fields are positional in declaration order (dataclass ``__init__``), so new
    fields should only ever be appended at the end.
    """

    # --- Core -----------------------------------------------------------------
    # Operator name. NOTE(review): build_mask in this module does not read it;
    # a multi-scale LoG is always used.
    operator: str = "LoG"
    # Gaussian scales for the multi-scale LoG. build_mask passes each entry
    # through float() and clamps it to >= 0.5, so fractional sigmas work too.
    sigma_list: List[float] = field(default_factory=lambda: [2, 3, 4])
    # "zscore" (case-insensitive) standardizes the input; any other value
    # (e.g. "none") leaves it unnormalized.
    normalize: str = "zscore"
    # Grammar: "quantile:<q>" with q in [0, 1]; pixels strictly above the
    # q-quantile of the positive response are selected. An unrecognized spec
    # selects nothing, which triggers the top-|LoG| fallback.
    threshold: str = "quantile:0.99"
    # 8 -> 8-connected component labelling; any other value -> 4-connected.
    connectivity: int = 8
    # Number of binary opening / closing passes applied to the raw selection.
    morph_open: int = 1
    morph_close: int = 1
    # Fill enclosed holes after opening/closing.
    fill_holes: bool = True
    # Drop components with fewer pixels than this (0 disables the pruning).
    remove_small_px: int = 0
    # "largest" keeps only the biggest connected component; any other value
    # (e.g. "all") keeps every component.
    keep: str = "largest"

    # --- Coverage steering (percent units; e.g. 0.5 -> 0.5%) -----------------
    # Minimum mask coverage; 0 disables the coverage bump / dilation step.
    min_coverage_pct: float = 0.0
    # Upper coverage bound used when re-thresholding to reach the minimum
    # (the effective bound is max(min_coverage_pct, target_coverage_pct)).
    target_coverage_pct: float = 0.5

    # --- Fallbacks ------------------------------------------------------------
    # Fraction of all pixels (e.g. 0.005 -> 0.5%) picked by the top-|LoG|
    # fallback when the threshold selects nothing; at least one pixel.
    fallback_topfrac: float = 0.005
    # When |LoG| is constant over the whole image (e.g. all zeros) nothing can
    # be ranked: True -> compact central disk, False -> single centre pixel.
    central_blob_when_allzero: bool = True

    # --- Optional prior (skipped when the path is unset or unusable) ----------
    use_kernel_prior: bool = False
    # Exponent applied to the max-normalized prior weights.
    kernel_prior_beta: float = 0.8
    # Path to a .npy/.npz file holding a (2, L, L) array. NOTE(review):
    # presumably one file per (gauge, L), filled in by the caller -- confirm.
    kernel_prior_path: Optional[str] = None


# --- helpers -----------------------------------------------------------------

def _zscore(x: np.ndarray) -> np.ndarray:
    """Standardize *x* to zero mean and unit std, ignoring NaNs in the statistics.

    NaN entries stay NaN in the output. If the spread is zero or not finite,
    ``x * 0.0`` is returned instead (zeros, with NaNs preserved).
    """
    mu = np.nanmean(x)
    sd = np.nanstd(x)
    degenerate = (not np.isfinite(sd)) or sd == 0
    return x * 0.0 if degenerate else (x - mu) / sd


def _parse_threshold(th: str) -> Tuple[str, float]:
    """
    Parse the threshold grammar used by ``CCTConfig.threshold``.

    Accepted forms:
    - ``"quantile:<q>"`` (case-insensitive, surrounding whitespace ignored)
    - a bare int/float, taken directly as the quantile ``q``

    Returns ``("quantile", q)`` only when ``q`` is a finite number in ``[0, 1]``
    (the range ``np.quantile`` accepts). Anything else -- unknown prefix,
    unparseable number, out-of-range or NaN quantile, other types -- yields
    ``("none", 0.0)``, so callers take their fallback path instead of crashing
    inside ``np.quantile``.
    """
    if isinstance(th, (int, float)):
        # numeric spec is taken directly as the quantile (rare)
        q = float(th)
    elif isinstance(th, str):
        spec = th.strip().lower()
        prefix = "quantile:"
        if not spec.startswith(prefix):
            return ("none", 0.0)
        try:
            q = float(spec[len(prefix):])
        except ValueError:
            return ("none", 0.0)
    else:
        return ("none", 0.0)
    # Also rejects NaN/inf: every comparison with NaN is False.
    if not (0.0 <= q <= 1.0):
        return ("none", 0.0)
    return ("quantile", q)


def _largest_cc(m: np.ndarray, connectivity: int) -> np.ndarray:
    """Keep only the biggest connected component of a 2-D mask.

    ``connectivity == 8`` links diagonal neighbours; any other value means
    4-connectivity. A mask with zero or one component is returned unchanged
    (the same object). On size ties the component with the lowest label wins.
    """
    if not m.any():
        return m
    # scipy's label() defaults to the 4-connected cross; build the 3x3 explicitly.
    if connectivity == 8:
        structure = np.ones((3, 3), dtype=int)
    else:
        structure = np.zeros((3, 3), dtype=int)
        structure[1, :] = 1
        structure[:, 1] = 1
    labels, count = label(m, structure=structure)
    if count <= 1:
        return m
    counts = np.bincount(labels.ravel())
    counts[0] = 0  # ignore the background label
    return labels == counts.argmax()


def _center_blob(shape: Tuple[int, int], target_count: int) -> np.ndarray:
    """Return a deterministic boolean disk centred on the image, ~*target_count* pixels big.

    The radius is ``sqrt(target_count / pi)`` (disk area ~ pi * R^2) but never
    below 1, so even tiny targets yield a few pixels around the centre. Large
    targets are simply clipped by the image border.
    """
    n_rows, n_cols = shape
    rows, cols = np.indices(shape)
    radius = max(1.0, np.sqrt(target_count / np.pi))
    dist = np.hypot(cols - (n_cols - 1) / 2.0, rows - (n_rows - 1) / 2.0)
    return dist <= radius


def _apply_morph(mask: np.ndarray, cfg: CCTConfig) -> np.ndarray:
    """Clean up a binary mask according to ``cfg``.

    Steps, in order:
    1. ``cfg.morph_open`` binary openings, then ``cfg.morph_close`` binary
       closings (scipy's default cross-shaped structuring element);
    2. hole filling if ``cfg.fill_holes``;
    3. largest-component selection if ``cfg.keep == "largest"``;
    4. removal of components smaller than ``cfg.remove_small_px`` pixels,
       labelled with ``cfg.connectivity`` (8 -> 8-connected, else 4-connected).

    The input is never modified; a new array is always returned.
    """
    m = mask.copy()
    # NOTE(review): opening/closing are idempotent in theory, so passes beyond
    # the first likely add little. If a larger structuring element was
    # intended, ``iterations=`` is the scipy knob for that.
    for _ in range(cfg.morph_open):
        m = binary_opening(m)
    for _ in range(cfg.morph_close):
        m = binary_closing(m)
    if cfg.fill_holes:
        m = binary_fill_holes(m)
    if cfg.keep == "largest":
        m = _largest_cc(m, cfg.connectivity)
    if cfg.remove_small_px > 0 and m.any():
        if cfg.connectivity == 8:
            structure = np.ones((3, 3), int)
        else:
            structure = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], int)
        lbl, _ = label(m, structure=structure)
        # One pass over the label image (component sizes via bincount) instead
        # of a full-image comparison per component.
        sizes = np.bincount(lbl.ravel())
        keep_label = sizes >= cfg.remove_small_px
        keep_label[0] = False  # label 0 is background
        m = keep_label[lbl]
    return m


# --- main build ---------------------------------------------------------------

def _multiscale_log(x: np.ndarray, sigma_list: List[float]) -> np.ndarray:
    """Sum of scale-normalized responses ``-s**2 * LoG_s(x)`` over *sigma_list*.

    Each sigma is converted with ``float`` and clamped to at least 0.5.
    """
    S = np.zeros_like(x, dtype=np.float64)
    for s in sigma_list:
        s = max(0.5, float(s))
        # gaussian_laplace returns the Laplacian of the Gaussian-smoothed input;
        # the s**2 factor scale-normalizes it so different sigmas are comparable.
        S += -(s ** 2) * gaussian_laplace(x, sigma=s)
    return S


def _kernel_prior_weights(path: str, L: int, beta: float) -> Optional[np.ndarray]:
    """Load a (2, L, L) link array from *path* and turn it into per-site weights in [0, 1].

    ``.npy`` files are used directly; for ``.npz`` archives the first stored
    array is used. Each site weight is the mean of four |K| entries (the site's
    own x/y entries plus the x entry one column over and the y entry one row
    over, with periodic wrap-around), divided by the maximum and raised to
    *beta*.

    Returns None when the file does not exist, the array shape does not match,
    or the weights are identically zero. Load/parse errors propagate.
    """
    import os

    if not os.path.exists(path):
        return None
    loaded = np.load(path, allow_pickle=False)
    if isinstance(loaded, np.lib.npyio.NpzFile):
        # NpzFile keeps the archive open; close it once the array is read.
        with loaded:
            K = loaded[loaded.files[0]]
    else:
        K = loaded
    if K.ndim != 3 or K.shape[1:] != (L, L):
        return None
    Kx, Ky = K[0], K[1]
    # NOTE(review): roll by -1 pairs site x with the entry stored at x+1;
    # confirm this matches the link-indexing convention of the prior files.
    node = (np.abs(Kx) + np.abs(np.roll(Kx, -1, axis=1)) +
            np.abs(Ky) + np.abs(np.roll(Ky, -1, axis=0))) / 4.0
    peak = node.max()
    if not peak > 0:
        return None
    return (node / peak) ** float(beta)


def _fallback_mask(S: np.ndarray, cfg: CCTConfig, diag: Dict) -> np.ndarray:
    """Select pixels when the threshold produced nothing; record the reason in *diag*.

    Targets ``k = max(1, round(cfg.fallback_topfrac * S.size))`` pixels. If |S|
    is constant over the whole image (e.g. all zeros) there is nothing to rank,
    so either a central disk of ~k pixels (``cfg.central_blob_when_allzero``)
    or the single centre pixel is returned. Otherwise the top-k pixels by |S|
    are chosen via ``cc_targetcov.topk_mask_by_score``.
    """
    absS = np.abs(S)
    k = max(1, int(round(float(cfg.fallback_topfrac) * absS.size)))
    if absS.max() == absS.min():
        if cfg.central_blob_when_allzero:
            diag["fallback_reason"] = "allzero_log_central_blob"
            return _center_blob(S.shape, k)
        M = np.zeros_like(S, dtype=bool)
        M[S.shape[0] // 2, S.shape[1] // 2] = True
        diag["fallback_reason"] = "allzero_log_single_pixel"
        return M
    from .cc_targetcov import topk_mask_by_score

    M = topk_mask_by_score(absS, k)
    diag["fallback_reason"] = "top_abs_log"
    return M


def _coverage_pct(mask: np.ndarray) -> float:
    """Percentage (0-100) of True pixels in *mask*."""
    return 100.0 * (mask.sum() / float(mask.size))


def _append_reason(reason: str, tag: str) -> str:
    """Append *tag* to a ``|``-separated reason string (no leading ``|`` when empty)."""
    return (reason + "|" + tag).strip("|")


def _enforce_min_coverage(
    M: np.ndarray, S: np.ndarray, cfg: CCTConfig, diag: Dict
) -> Tuple[np.ndarray, float]:
    """Try to grow *M* to at least ``cfg.min_coverage_pct`` percent coverage.

    The config values are in percent, while ``mask_with_target_coverage``
    works with fractions (0-1), so they are divided by 100 here.

    Step 1 re-thresholds the positive part of *S* with
    ``cc_targetcov.mask_with_target_coverage`` and runs the result through
    ``_apply_morph``. The candidate replaces *M* only if it does not lower
    coverage, because this step exists to enlarge the mask. Step 2 dilates
    with a 3x3 (8-connected) structure, at most 8 times. Both steps are
    skipped once the minimum is met. Returns ``(mask, coverage_pct)``.
    """
    from scipy.ndimage import binary_dilation

    min_pct = float(cfg.min_coverage_pct)
    cov = _coverage_pct(M)
    if not cov < min_pct:
        return M, cov

    # Step 1: quantile re-threshold on the positive response.
    pos_score = np.where(S > 0, S, 0.0)
    low = min_pct / 100.0
    high = float(max(cfg.min_coverage_pct, cfg.target_coverage_pct)) / 100.0
    try:
        from .cc_targetcov import mask_with_target_coverage

        # returns (mask, quantile_used, coverage_fraction)
        candidate, _, _ = mask_with_target_coverage(
            pos_score,
            target_low=low,
            target_high=high,
        )
    except Exception:
        # Best effort: a failing helper just skips this step.
        candidate = None
    if candidate is not None and candidate.any():
        bumped = _apply_morph(candidate, cfg)
        bumped_cov = _coverage_pct(bumped)
        if bumped_cov >= cov:
            M, cov = bumped, bumped_cov
            diag["used_fallback"] = True
            diag["fallback_reason"] = _append_reason(diag.get("fallback_reason", ""), "bumped_min")

    # Step 2: bounded dilation (an empty mask cannot grow, so skip it).
    if cov < min_pct and M.any():
        struct = np.ones((3, 3), dtype=bool)
        for _ in range(8):  # safety cap against runaway growth
            M = binary_dilation(M, structure=struct)
            cov = _coverage_pct(M)
            if not cov < min_pct:
                break
        diag["used_fallback"] = True
        diag["fallback_reason"] = _append_reason(diag.get("fallback_reason", ""), "dilated_min")
    return M, cov


def build_mask(
    e0: np.ndarray,
    cfg: CCTConfig,
) -> Tuple[np.ndarray, Dict]:
    """
    Build S⁺ mask from an E0 snapshot.

    Pipeline:
    1. Optional z-score normalization.
    2. Multi-scale, scale-normalized LoG.
    3. Optional kernel-prior weighting.
    4. Auto-polarity: flip the sign if the negative part carries more mass.
    5. Quantile threshold on the positive tail.
    6. Fallback selection if nothing was selected.
    7. Morphology via ``_apply_morph``. If that erases a non-empty selection,
       the pre-morphology selection is kept, reduced to its largest component
       when ``cfg.keep == "largest"``.
    8. Optional growth up to ``cfg.min_coverage_pct``.

    Parameters
    ----------
    e0 : np.ndarray
        Square 2-D snapshot. Other shapes return an all-False mask with
        ``fallback_reason == "bad_shape"``.
    cfg : CCTConfig
        Pipeline settings. An empty ``sigma_list`` returns an all-False mask
        with ``fallback_reason == "no_sigma"``.

    Returns
    -------
    (mask_bool, diagnostics_dict)
        The diagnostics contain:
        - ``sigma_list``, ``polarity`` (``"pos"``/``"neg"``),
          ``threshold_kind``, ``threshold_value`` and ``used_fallback``.
        - ``fallback_reason``: ``|``-joined tags, e.g. ``top_abs_log``,
          ``allzero_log_central_blob``, ``allzero_log_single_pixel``,
          ``morph_emptied``, ``bumped_min``, ``dilated_min``.
        - ``coverage_pct``: final coverage in percent.
        - ``kernel_prior`` (``"applied"``, ``"skipped"`` or ``"error:..."``):
          present only when a prior was requested.
    """
    diag: Dict = {
        "sigma_list": list(cfg.sigma_list),
        "polarity": "pos",
        "threshold_kind": None,
        "threshold_value": None,
        "used_fallback": False,
        "fallback_reason": "",
        "coverage_pct": 0.0,
    }

    if e0.ndim != 2 or e0.shape[0] != e0.shape[1]:
        return np.zeros_like(e0, dtype=bool), {**diag, "fallback_reason": "bad_shape"}

    L = e0.shape[0]
    x = e0.astype(np.float64, copy=False)

    if cfg.normalize.lower() == "zscore":
        x = _zscore(x)

    if len(cfg.sigma_list) == 0:
        return np.zeros_like(x, dtype=bool), {**diag, "fallback_reason": "no_sigma"}
    S = _multiscale_log(x, cfg.sigma_list)

    # Optional kernel prior (soft weighting). Strictly optional: failures skip
    # the prior instead of failing the run, but the outcome is recorded.
    if cfg.use_kernel_prior and cfg.kernel_prior_path:
        try:
            prior = _kernel_prior_weights(cfg.kernel_prior_path, L, cfg.kernel_prior_beta)
        except Exception as err:
            prior = None
            diag["kernel_prior"] = f"error:{type(err).__name__}: {err}"
        else:
            diag["kernel_prior"] = "skipped" if prior is None else "applied"
        if prior is not None:
            S = S * prior

    # Auto-polarity: keep the sign whose positive part carries more mass.
    pos_mass = np.sum(S[S > 0])
    neg_mass = np.sum((-S)[S < 0])
    if neg_mass > pos_mass:
        S = -S
        diag["polarity"] = "neg"
    else:
        diag["polarity"] = "pos"

    # Positive-tail quantile threshold.
    kind, q = _parse_threshold(cfg.threshold)
    if kind == "quantile" and not (0.0 <= q <= 1.0):
        # np.quantile rejects q outside [0, 1] (and NaN): treat as no threshold.
        kind = "none"
    diag["threshold_kind"] = kind
    diag["threshold_value"] = q

    M = np.zeros_like(S, dtype=bool)
    if kind == "quantile":
        pos_vals = S[S > 0]
        if pos_vals.size > 0:
            M = S > np.quantile(pos_vals, q)

    # Empty selection -> top-|LoG| fraction (or central blob / pixel).
    if not M.any():
        diag["used_fallback"] = True
        M = _fallback_mask(S, cfg, diag)

    # Morphology + largest component. A 3x3 opening removes isolated pixels
    # and very thin/small regions. Without this guard, the single-pixel or
    # tiny fallbacks above would end up as an empty mask.
    selected = M
    M = _apply_morph(M, cfg)
    if not M.any() and selected.any():
        M = np.asarray(selected, dtype=bool)
        if cfg.keep == "largest":
            M = _largest_cc(M, cfg.connectivity)
        diag["used_fallback"] = True
        diag["fallback_reason"] = _append_reason(diag["fallback_reason"], "morph_emptied")

    # Enforce the minimum coverage last (percent units in the config).
    M, cov = _enforce_min_coverage(M, S, cfg, diag)
    diag["coverage_pct"] = float(cov)
    return M.astype(bool, copy=False), diag
